#!/usr/bin/python
import code,pickle,sys,os,re
from pylab import *
from optparse import OptionParser
from ocrolib import dbtables
from ocrolib.native import *
import ocrolib

# Command-line interface: three positional database paths plus tuning options.
parser = OptionParser(usage="""
usage: %prog [options] chars.db som.db mds.db

""")

# Boolean switches (left as None unless given on the command line).
parser.add_option("-D","--display",
                  action="store_true",
                  help="display chars")
parser.add_option("-v","--verbose",
                  action="store_true",
                  help="verbose output")
parser.add_option("-r","--torus",
                  action="store_true",
                  help="use toroidal topology")

# String-valued options.
parser.add_option("-t","--table",
                  default="chars",
                  help="table name")
parser.add_option("-g","--grid",
                  default="10,10",
                  help="grid size")

# Numeric options.
parser.add_option("-T","--threshold",
                  type="float",default=0.03,
                  help="threshold")
parser.add_option("-s","--scale",
                  type="float",default=3000.0,
                  help="scale for iterating")
parser.add_option("-n","--niter",
                  type="int",default=100000,
                  help="max number of iterations")

from scipy import mgrid,linalg,ndimage
import sys,os,random,math
import numpy,pylab,scipy
from numpy import *

# Module-level flag read by som() to decide whether to print progress;
# it is a fixed value here and is not derived from the -v option.
verbose = 1

options,args = parser.parse_args()

# Exactly three positional arguments are required (input table, SOM output,
# per-sample output); otherwise print the help text and stop.
if len(args)!=3:
    parser.print_help()
    sys.exit(0)

input,output,output2 = args

def rchoose(k,n):
    """Return k distinct indices drawn from 0..n-1, in random order.

    The indices are the first k entries of a random permutation of
    range(n). Requires k<=n (checked with an assert)."""
    assert k<=n
    shuffled = random.permutation(range(n))
    return shuffled[:k]
def rowwise(f,data,samples=None):
    """Apply f to rows of the 2D array data and collect the results in an array.

    samples -- optional sequence of row indices to visit, in that order;
               when None every row is visited in order."""
    assert data.ndim==2
    indices = range(len(data)) if samples is None else samples
    results = []
    for index in indices:
        results.append(f(data[index]))
    return array(results)
def argmindist_slow(x,data):
    """Pure-Python reference: index of the element of data with the
    smallest squared distance (distsq) to x."""
    sqdists = []
    for candidate in data:
        sqdists.append(distsq(x,candidate))
    return argmin(sqdists)
def dist(u,v):
    """Return linalg.norm of the difference u-v."""
    delta = u-v
    return linalg.norm(delta)
def distsq(x,y):
    """Return the sum of squared element differences between x and y.

    The difference is flattened first, so inputs of any (matching) shape work."""
    flat = (x-y).ravel()
    return dot(flat,flat)

# Native nearest-row search. argmindist0(r,d,v,vs) computes the squared
# Euclidean distance from the d-element float vector v to each of the r rows
# of vs and returns the index of the smallest one; on ties the first such row
# wins (later rows only replace the minimum when strictly smaller).
# compile_and_load and the I/A1F/A2F type tags are not defined in this file;
# presumably they come from the ocrolib.native star import -- confirm.
native_argmin = compile_and_load(r'''
#include <math.h>
#include <stdio.h>
#include <assert.h>
#include <stdlib.h>

int argmindist0(int r,int d,float v[d],float vs[r][d]) {
    double ds[r];
#pragma omp parallel for
    for(int i=0;i<r;i++) {
        double total = 0.0;
        for(int j=0;j<d;j++) {
            float delta = v[j]-vs[i][j];
            total += delta*delta;
        }
        assert(!isnan(total));
        ds[i] = total; 
    }
    double md = ds[0];
    double mi = 0;
    for(int i=1;i<r;i++) {
        if(ds[i]>=md) continue;
        md = ds[i];
        mi = i;
    }
    return mi;
}
''')

native_argmin.argmindist0.argtypes = [I,I,A1F,A2F]
# The return type must be declared on the function object; assigning it to
# the library handle only creates an unused attribute on the handle.
native_argmin.argmindist0.restype = I

def argmindist(v,data):
    """Return the index of the row of data nearest to v, using the native
    argmindist0 routine.

    v must have as many elements as data has columns (asserted).
    NOTE(review): the native signature is declared with A1F/A2F, so both
    arguments presumably have to be float arrays in the layout those tags
    require -- confirm against their definition."""
    v = v[:]
    ncols = data.shape[1]
    assert len(v)==ncols
    nrows = data.shape[0]
    return native_argmin.argmindist0(nrows,ncols,v,data)

def som_theta(dist,i,scale=options.scale):
    """Compute a SOM theta value used for updating.
    (This is the default; you can define your own.)

    dist  -- grid distance between a cell and the best-matching cell
    i     -- iteration number
    scale -- controls how quickly both the learning rate and the
             neighborhood width shrink as i grows

    Returns scale/(scale+i) * exp(-dist/(2*sigma)) with
    sigma = 10*scale/(scale+i); values below 1e-3 are returned as 0."""
    # Force floating-point arithmetic: with integer arguments Python 2 would
    # floor-divide, so that theta(1,i) differs from theta(1.0,i) and an
    # integer scale collapses scale/(scale+i) to 0.
    scale = float(scale)
    sigma = 10.0 * scale/(scale+i)
    t = scale/(scale+i) * exp(-dist/2.0/sigma)
    if t<1e-3: return 0
    return t

def som(data,shape=None,niter=10000000,threshold=0.03,theta=som_theta,torus=0):
    assert not isnan(data).any()
    """Compute a 2D self-organizing map for the data,
    with the given shape and the maximum number of iterations.
    The theta value used for updating is computed by the theta
    function passed as an argument."""
    if shape is None:
        k = max(3,floor(data.shape[0]**(1.0/3.0)))
        shape = (k,k)
    assert shape[0]>=3 and shape[1]>=3
    w,h = shape
    n,m = data.shape
    total = w*h
    items = rchoose(total,n)
    # print items
    grid = data[items,:].copy()
    for i in range(niter):
        neighbor_update = theta(1.0,i)
        if neighbor_update<threshold: break
        best = argmindist(data[i%n],grid.reshape(w*h,m))
        x,y = best/h,best%h
        if verbose and i%1000==0:
            print i,x,y,theta(1,i)
        if theta(1,i)<1e-2: break
        for index in range(w*h):
            u,v = index/h,index%h
            dx = u-x
            dy = v-y
            if torus:
                if abs(dx)>w/2: dx = abs(dx)-w
                if abs(dy)>h/2: dy = abs(dy)-h
            d = math.hypot(dx,dy)
            t = theta(d,i)
            if t<1e-8: continue
            diff = data[i%n,:]-grid[index,:]
            grid[index,:] += t * diff
    grid.shape = (w,h,m)
    return grid

def coords0(v,grid):
    """Return the (row,column) index of the grid cell closest to v.

    grid -- array of shape (w,h,m) holding one m-vector per cell
    v    -- m-vector to locate

    Closeness is Euclidean; on ties the first cell in row-major order wins.
    Both returned indices are integers."""
    w,h,m = grid.shape
    vs = grid.reshape(w*h,m)
    # Squared distances to all cells in one vectorized pass; the square root
    # is skipped because it does not change which cell is nearest.
    sqdists = ((vs-v)**2).sum(axis=1)
    # divmod keeps the row index integral even under true division.
    return divmod(sqdists.argmin(),h)

def coords1(v,grid):
    """Return fractional (row,column) coordinates of v on the grid as a
    weighted average over all cells.

    grid -- array of shape (w,h,m) holding one m-vector per cell
    v    -- m-vector to locate

    Each cell's weight is (1 - (d-dmin)/(dmax-dmin))**2, where d is its
    Euclidean distance to v, so the nearest cell gets weight 1 and the
    farthest weight 0. When every cell is equally far from v the weights
    are uniform and the grid centroid is returned."""
    w,h,m = grid.shape
    n = w*h
    vs = grid.reshape(n,m)
    ds = sqrt(((vs-v)**2).sum(axis=1))
    # Row/column index of every cell, in the same row-major order as vs.
    cells = arange(n)
    rows,cols = cells//h,cells%h
    nearest = ds.min()
    spread = ds.max()-nearest
    if spread==0:
        # All cells equally distant: the normalization would be 0/0.
        weights = ones(n)
    else:
        weights = (1.0-(ds-nearest)/spread)**2
    total = weights.sum()
    x = (rows*weights).sum()/total
    y = (cols*weights).sum()/total
    return x,y

def coords(v,grid):
    """Return fractional (row,column) coordinates of v on the grid, averaged
    over the best-matching cell and its immediate neighbours.

    grid -- array of shape (w,h,m) holding one m-vector per cell
    v    -- m-vector to locate

    The neighbourhood is every cell whose row and column are each within 1
    of the nearest cell (no wrap-around at the edges). Each such cell is
    weighted by 1 - (d-dmin)/(dmax-dmin), d being its Euclidean distance
    to v. When all neighbourhood distances are equal the weights are
    uniform, which avoids a 0/0 that would otherwise yield NaN."""
    w,h,m = grid.shape
    n = w*h
    vs = grid.reshape(n,m)
    # Distances to all cells in one vectorized pass.
    ds = sqrt(((vs-v)**2).sum(axis=1))
    # divmod keeps the cell indices integral even under true division.
    bx,by = divmod(int(ds.argmin()),h)
    # Row/column index of every cell, in the same row-major order as vs.
    cells = arange(n)
    rows,cols = cells//h,cells%h
    near = (abs(rows-bx)<=1)&(abs(cols-by)<=1)
    bds = ds[near].astype('f')
    bxs = rows[near].astype('f')
    bys = cols[near].astype('f')
    nearest = bds.min()
    spread = bds.max()-nearest
    if spread==0:
        weights = ones(len(bds),'f')
    else:
        weights = 1.0-(bds-nearest)/spread
    total = weights.sum()
    x = (bxs*weights).sum()/total
    y = (bys*weights).sum()/total
    return x,y

ion()
show()

table = dbtables.Table(input,options.table)
table.converter("image",dbtables.SmallImage())
table.create(image="blob",cls="text",classes="text")

classlist = [row[0] for row in table.query("select distinct(cls) from '%s' order by cls"%
                                         options.table)]

extractor = ocrolib.ScaledFE()

data = []
classes = []
print "loading"
for cls in classlist:
    print "cls",cls
    for row in table.get(cls=cls):
        raw = row.image
        if raw.shape[0]>255 or raw.shape[1]>255: continue
        c = raw/float(amax(raw))
        v = extractor.extract(c)
        data.append(v)
        classes.append(cls)

assert len(data)>10,"not enough input data elements (maybe you need to specify a different table?)"

data = array(data)
print "clustering",data.shape
shape = eval("("+options.grid+")")
print "niter",options.niter,"scale",options.scale,"shape",shape
grid = som(data,shape,torus=options.torus,niter=options.niter)

print "writing"
table = dbtables.ClusterTable(output)
table.create(image="blob",cls="text",count="integer",classes="text",grid="text")
table.converter("image",dbtables.SmallImage())
for i in range(shape[0]):
    for j in range(shape[1]):
        v = grid[i,j]
        image = array(v/amax(v)*255.0,'B')
        image.shape = (30,30)
        table.set(image=image,cls="_",count=1,classes="",grid="%d %d"%(i,j))

print "writing2"
table = dbtables.ClusterTable(output2)
table.create(image="blob",cls="text",count="integer",classes="text",mds="text")
table.converter("image",dbtables.SmallImage())
for i in range(len(data)):
    v = data[i]
    x,y = coords(v,grid)
    image = array(v/amax(v)*255.0,'B')
    image.shape = (30,30)
    table.set(image=image,cls=classes[i],count=1,classes="",mds="%f %f"%(x,y))
